#include "cwebsocket.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <fcntl.h>
#include <ctype.h>
#include <inttypes.h>
#include <errno.h>

#ifdef WIN32
#define strncasecmp strnicmp

/* strndup() replacement for Windows CRTs that lack it.
 *
 * Returns a newly malloc()ed, NUL-terminated copy of at most `len` bytes of
 * `str`, stopping early at a NUL byte (POSIX semantics). In particular,
 * len == 0 yields an empty string, and `str` need not be NUL-terminated
 * within `len` bytes. Returns NULL if the allocation fails; the caller frees.
 */
char* strndup(const char* str, size_t len)
{
	/* memchr() stops at the first NUL, so short strings are never overread. */
	const char *nul = (const char *)memchr(str, '\0', len);
	size_t n = nul ? (size_t)(nul - str) : len;
	char *rtn = (char *)malloc(n + 1);

	if (!rtn)
		return NULL;
	memcpy(rtn, str, n);
	rtn[n] = '\0';
	return rtn;
}

#endif

extern void _cws_SHA1(uint8_t *hash_out,const char *str,int len);

/* Hex-dump `len` bytes of `buffer` to stderr. Printable bytes are shown
 * with their character as well. With a non-NULL prefix the dump is labelled
 * "prefix:" and terminated with a newline; otherwise it is emitted bare. */
static inline void _cws_debug(const char *prefix, const void *buffer, size_t len)
{
	const uint8_t *data = (const uint8_t *)buffer;
	const bool labelled = (prefix != NULL);

	if (labelled)
		fprintf(stderr, "%s:", prefix);

	for (size_t pos = 0; pos < len; ++pos) {
		const uint8_t byte = data[pos];

		if (!isprint(byte)) {
			fprintf(stderr, " %#04x", byte);
			continue;
		}
		fprintf(stderr, " %#04x(%c)", byte, byte);
	}

	if (labelled)
		fprintf(stderr, "\n");
}

/* Base64-encode `input_len` bytes of `input` into `output` using the
 * standard alphabet with '=' padding. Writes exactly 4 * ceil(input_len / 3)
 * characters and does NOT NUL-terminate; the caller sizes and terminates
 * `output`. */
static void _cws_encode_base64(const uint8_t *input, const size_t input_len, char *output)
{
	static const char alphabet[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
	const char pad = alphabet[64];
	const uint8_t *src = input;
	size_t remaining = input_len;
	char *dst = output;

	/* Whole 3-byte groups: pack into 24 bits, emit four 6-bit digits. */
	while (remaining >= 3) {
		const uint32_t group = ((uint32_t)src[0] << 16) |
		                       ((uint32_t)src[1] << 8) |
		                       (uint32_t)src[2];
		*dst++ = alphabet[(group >> 18) & 0x3f];
		*dst++ = alphabet[(group >> 12) & 0x3f];
		*dst++ = alphabet[(group >> 6) & 0x3f];
		*dst++ = alphabet[group & 0x3f];
		src += 3;
		remaining -= 3;
	}

	/* A trailing one or two bytes are zero-extended and padded with '='. */
	if (remaining == 1) {
		const uint32_t group = (uint32_t)src[0] << 16;
		*dst++ = alphabet[(group >> 18) & 0x3f];
		*dst++ = alphabet[(group >> 12) & 0x3f];
		*dst++ = pad;
		*dst = pad;
	} else if (remaining == 2) {
		const uint32_t group = ((uint32_t)src[0] << 16) | ((uint32_t)src[1] << 8);
		*dst++ = alphabet[(group >> 18) & 0x3f];
		*dst++ = alphabet[(group >> 12) & 0x3f];
		*dst++ = alphabet[(group >> 6) & 0x3f];
		*dst = pad;
	}
}

/* Fill `len` bytes at `buffer` with random data. Used for the handshake
 * nonce (Sec-WebSocket-Key) and for every frame's masking key.
 *
 * RFC 6455 (sections 5.3 and 10.3) requires the masking key to be
 * unpredictable. rand() is a predictable PRNG that yields the same sequence
 * in every process unless srand() is called, so bytes are read from
 * /dev/urandom where it exists. rand() only fills whatever could not be
 * read (e.g. on WIN32, where there is no /dev/urandom).
 */
static void _cws_get_random(void *buffer, size_t len)
{
	uint8_t *bytes = (uint8_t *)buffer;
	size_t filled = 0;

#ifndef WIN32
	FILE *urandom = fopen("/dev/urandom", "rb");
	if (urandom) {
		/* Unbuffered: read only the requested bytes, not a full stdio buffer. */
		setvbuf(urandom, NULL, _IONBF, 0);
		filled = fread(bytes, 1, len, urandom);
		fclose(urandom);
	}
#endif

	for (; filled < len; filled++)
		bytes[filled] = rand() & 0xff;
}

/* Narrow [*p_buffer, *p_buffer + *p_len) in place so that it has no leading
 * or trailing whitespace, as classified by isspace(). The bytes themselves
 * are not modified; only the pointer and length are updated. */
static inline void _cws_trim(const char **p_buffer, size_t *p_len)
{
	const char *buffer = *p_buffer;
	size_t len = *p_len;

	/* <ctype.h> functions need a value representable as unsigned char (or
	 * EOF). Passing a plain, possibly signed char >= 0x80 taken from the
	 * network is undefined behavior, so cast first. */
	while (len > 0 && isspace((unsigned char)buffer[0])) {
		buffer++;
		len--;
	}

	while (len > 0 && isspace((unsigned char)buffer[len - 1]))
		len--;

	*p_buffer = buffer;
	*p_len = len;
}

/* True when the first `buflen` bytes of `buffer` begin with `prefix`,
 * compared ASCII case-insensitively (HTTP header names are case-insensitive).
 * `buffer` need not be NUL-terminated. */
static inline bool _cws_header_has_prefix(const char *buffer, const size_t buflen, const char *prefix) {
	const size_t need = strlen(prefix);

	return buflen >= need && strncasecmp(buffer, prefix, need) == 0;
}

/* True when the host stores multi-byte integers least-significant byte
 * first. This is decided at runtime rather than with byte-order macros:
 * GCC's __BYTE_ORDER__ / __ORDER_*_ENDIAN__ and glibc's __BYTE_ORDER /
 * __BIG_ENDIAN are different families, and MSVC defines neither. Optimizing
 * compilers fold this check to a constant. */
static inline bool _cws_host_is_little_endian(void)
{
	const uint16_t probe = 1;
	uint8_t first;

	memcpy(&first, &probe, 1);
	return first == 1;
}

/* Reverse the `len` bytes at `mem` in place on little-endian hosts; no-op on
 * big-endian hosts. Odd lengths (including 1) are left untouched. */
static inline void _cws_swap_if_little_endian(void *mem, uint8_t len)
{
	uint8_t *bytes = (uint8_t *)mem;
	uint8_t i;

	if (len % 2 || !_cws_host_is_little_endian())
		return;

	for (i = 0; i < len / 2; i++) {
		uint8_t tmp = bytes[i];
		bytes[i] = bytes[len - i - 1];
		bytes[len - i - 1] = tmp;
	}
}

/* Convert the `len`-byte host-order integer at `mem` to network (big-endian)
 * order in place. */
static inline void _cws_hton(void *mem, uint8_t len)
{
	_cws_swap_if_little_endian(mem, len);
}

/* Convert the `len`-byte network-order (big-endian) integer at `mem` to host
 * order in place. */
static inline void _cws_ntoh(void *mem, uint8_t len)
{
	_cws_swap_if_little_endian(mem, len);
}


/* Print a printf-style diagnostic to stderr, prefixed with "ERROR: " and
 * terminated by a newline. `, ## __VA_ARGS__` (a GNU extension, also
 * accepted by MSVC) drops the comma when only a format is given. */
#define ERR(fmt, ...)                                   \
    fprintf(stderr, "ERROR: " fmt "\n", ## __VA_ARGS__)

/* Evaluates to p, or to "" when p is NULL. The argument is parenthesized so
 * that expressions such as STR_OR_EMPTY(a ? b : c) expand as intended. Note
 * that p is evaluated twice when it is non-NULL. */
#define STR_OR_EMPTY(p) ((p) != NULL ? (p) : "")

/* Size of the stack buffer _cws_write_masked() uses to mask outgoing
 * payloads chunk by chunk.
 */
#define CWS_MASK_TMPBUF_SIZE 4096

/* WebSocket frame opcodes (RFC 6455 section 5.2). cws_opcode_is_control()
 * classifies the first three as data frames and the last three as control
 * frames. */
enum cws_opcode {
    /* data frames */
    CWS_OPCODE_CONTINUATION = 0x0, /* further fragment of a message */
    CWS_OPCODE_TEXT         = 0x1,
    CWS_OPCODE_BINARY       = 0x2,

    /* control frames */
    CWS_OPCODE_CLOSE        = 0x8,
    CWS_OPCODE_PING         = 0x9,
    CWS_OPCODE_PONG         = 0xa
};

/* True for control opcodes. Only continuation, text and binary count as data
 * frames; every other value, including unknown or reserved ones, is treated
 * as control. */
static bool cws_opcode_is_control(enum cws_opcode opcode) {
    const bool is_data = opcode == CWS_OPCODE_CONTINUATION ||
                         opcode == CWS_OPCODE_TEXT ||
                         opcode == CWS_OPCODE_BINARY;

    return !is_data;
}

/* True when `r` is acceptable as the status code of a received Close frame:
 * one of the explicitly listed codes, or any value inside the IANA registry
 * or private-use ranges. */
static bool cws_close_reason_is_valid(enum cws_close_reason r) {
    /* Rejected explicitly. Presumably these are the codes RFC 6455 forbids
     * on the wire (1005/1006) -- confirm against cwebsocket.h. */
    if (r == CWS_CLOSE_REASON_NO_REASON || r == CWS_CLOSE_REASON_ABRUPTLY)
        return false;

    switch (r) {
    case CWS_CLOSE_REASON_NORMAL:
    case CWS_CLOSE_REASON_GOING_AWAY:
    case CWS_CLOSE_REASON_PROTOCOL_ERROR:
    case CWS_CLOSE_REASON_UNEXPECTED_DATA:
    case CWS_CLOSE_REASON_INCONSISTENT_DATA:
    case CWS_CLOSE_REASON_POLICY_VIOLATION:
    case CWS_CLOSE_REASON_TOO_BIG:
    case CWS_CLOSE_REASON_MISSING_EXTENSION:
    case CWS_CLOSE_REASON_SERVER_ERROR:
    case CWS_CLOSE_REASON_IANA_REGISTRY_START:
    case CWS_CLOSE_REASON_IANA_REGISTRY_END:
    case CWS_CLOSE_REASON_PRIVATE_START:
    case CWS_CLOSE_REASON_PRIVATE_END:
        return true;
    default:
        break;
    }

    /* Anything else must fall inside one of the two open ranges. */
    return (r >= CWS_CLOSE_REASON_IANA_REGISTRY_START && r <= CWS_CLOSE_REASON_IANA_REGISTRY_END) ||
           (r >= CWS_CLOSE_REASON_PRIVATE_START && r <= CWS_CLOSE_REASON_PRIVATE_END);
}

/*
 * WebSocket is a framed protocol in the format:
 *
 *    0                   1                   2                   3
 *    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 *   +-+-+-+-+-------+-+-------------+-------------------------------+
 *   |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
 *   |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
 *   |N|V|V|V|       |S|             |   (if payload len==126/127)   |
 *   | |1|2|3|       |K|             |                               |
 *   +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
 *   |     Extended payload length continued, if payload len == 127  |
 *   + - - - - - - - - - - - - - - - +-------------------------------+
 *   |                               |Masking-key, if MASK set to 1  |
 *   +-------------------------------+-------------------------------+
 *   | Masking-key (continued)       |          Payload Data         |
 *   +-------------------------------- - - - - - - - - - - - - - - - +
 *   :                     Payload Data continued ...                :
 *   + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
 *   |                     Payload Data continued ...                |
 *   +---------------------------------------------------------------+
 *
 * See https://tools.ietf.org/html/rfc6455#section-5.2
 */
/* In-memory image of the first two bytes of a frame header. It is copied
 * to and from the wire with a plain byte copy (_cws_send() writes it with
 * _cws_write(); _cws_process_frame() memcpy()s it out of recv.tmpbuf).
 *
 * NOTE(review): the order in which bit-fields are allocated within a storage
 * unit is implementation-defined (C11 6.7.2.1p11), and uint8_t bit-fields are
 * themselves an implementation-defined extension. This declaration matches
 * the wire layout only when the compiler fills each byte starting at the
 * least-significant bit and packs the struct into exactly 2 bytes. Confirm
 * this for every supported compiler/target; big-endian ABIs typically
 * allocate from the most-significant bit.
 */
struct cws_frame_header {
    /* first byte: fin + opcode
     * (opcode in the low 4 bits, then RSV1-3, FIN in the top bit, assuming
     * LSB-first allocation) */
    uint8_t opcode : 4;
    uint8_t _reserved : 3;
    uint8_t fin : 1;

    /* second byte: mask + payload length */
    uint8_t payload_len : 7; /* if 126, uses extra 2 bytes (uint16_t)
                              * if 127, uses extra 8 bytes (uint64_t)
                              * if <=125 is self-contained
                              */
    uint8_t mask : 1; /* if 1, uses 4 extra bytes */
};

/* Per-connection state. It is attached to the curl easy handle with
 * CURLOPT_PRIVATE and handed to every curl callback as its user data. */
struct cws_data {
    CURL *easy;                 /* the curl easy handle this state belongs to */
    struct cws_callbacks cbs;   /* copy of the user callbacks given to cws_new() */
    struct {
        char *requested;        /* presumably the protocols passed to cws_new(); freed in _cws_cleanup() */
        char *received;         /* Sec-WebSocket-Protocol response value, or NULL; reset on every HTTP status line */
    } websocket_protocols;
    struct curl_slist *headers; /* header list freed in _cws_cleanup(); presumably the request headers -- confirm in cws_new() */
    char accept_key[29];        /* expected Sec-WebSocket-Accept: 28 base64 chars + NUL */
    struct {
        /* Frame currently being received. payload is allocated with
         * total + 1 bytes and kept NUL-terminated at payload[used]. */
        struct {
            uint8_t *payload;
            uint64_t used;
            uint64_t total;
            enum cws_opcode opcode;
            bool fin;
        } current;
        /* Data of an unfinished fragmented message, parked between frames.
         * opcode == 0 means no fragmented message is in progress. */
        struct {
            uint8_t *payload;
            uint64_t used;
            uint64_t total;
            enum cws_opcode opcode;
        } fragmented;

        /* Accumulates the 2-byte header plus up to 8 bytes of extended length. */
        uint8_t tmpbuf[sizeof(struct cws_frame_header) + sizeof(uint64_t)];
        uint8_t done; /* of tmpbuf, for header */
        uint8_t needed; /* of tmpbuf, for header */
    } recv;
    struct {
        uint8_t *buffer; /* queued outgoing bytes: appended by _cws_write(), drained by _cws_send_data() */
        size_t len;
    } send;
    uint8_t dispatching;  /* callback nesting depth; _cws_cleanup() does nothing while > 0 */
    uint8_t pause_flags;  /* CURLPAUSE_* bits currently applied (CURLPAUSE_SEND while the send queue is empty) */
    bool accepted;        /* Sec-WebSocket-Accept matched accept_key */
    bool upgraded;        /* set by _cws_check_connection() ("Connection:" header) */
    bool connection_websocket; /* set by _cws_check_upgrade() ("Upgrade:" header) */
    bool closed;          /* a Close was requested/sent; _cws_send() refuses further frames */
    bool deleted;         /* when set, _cws_cleanup() frees this state and the easy handle */
};

/* Append `len` bytes to the outgoing queue (priv->send). If the read callback
 * had paused sending because the queue was empty, resume it. Returns false,
 * leaving the queue untouched, when the queue cannot grow. */
static bool _cws_write(cws_data *priv, const void *buffer, size_t len) {
    /* optimization note: growth could be rounded up (next power-of-2,
     * page size, ...) and, with a read offset kept in priv->send,
     * _cws_send_data() could skip its memmove() and compaction could
     * happen here instead of a realloc().
     */
    const size_t queued = priv->send.len;
    uint8_t *grown = (uint8_t *)realloc(priv->send.buffer, queued + len);

    if (grown == NULL)
        return false;

    memcpy(grown + queued, buffer, len);
    priv->send.buffer = grown;
    priv->send.len = queued + len;

    if ((priv->pause_flags & CURLPAUSE_SEND) == 0)
        return true;

    /* Data is available again: lift the send pause set by _cws_send_data(). */
    priv->pause_flags &= ~CURLPAUSE_SEND;
    curl_easy_pause(priv->easy, priv->pause_flags);
    return true;
}

/*
 * Mask is:
 *
 *     for i in len:
 *         output[i] = input[i] ^ mask[i % 4]
 *
 * Here a temporary buffer is used to reduce number of "write" calls
 * and pointer arithmetic to avoid counters.
 */
static bool _cws_write_masked(cws_data *priv, const uint8_t mask[4], const void *buffer, size_t len) {
    /* XOR each payload byte with mask[offset % 4], where offset counts from
     * the start of the payload, and queue the result through a fixed stack
     * chunk. This calls _cws_write() once per CWS_MASK_TMPBUF_SIZE bytes. */
    const uint8_t *src = (const uint8_t *)buffer;
    uint8_t chunk[CWS_MASK_TMPBUF_SIZE];
    size_t offset = 0;

    while (offset < len) {
        size_t n = len - offset;
        size_t k;

        if (n > sizeof(chunk))
            n = sizeof(chunk);

        for (k = 0; k < n; k++)
            chunk[k] = src[offset + k] ^ mask[(offset + k) & 0x3];

        if (!_cws_write(priv, chunk, n))
            return false;

        offset += n;
    }

    return true;
}

/* Queue one complete (FIN=1), masked frame with the given opcode and payload.
 * Returns false if the connection is already closed or the send queue cannot
 * grow. */
static bool _cws_send(cws_data *priv, enum cws_opcode opcode, const void *msg, size_t msglen) {
    struct cws_frame_header fh;
    uint8_t mask[4];

    if (priv->closed) {
        ERR("cannot send data to closed WebSocket connection %p", priv->easy);
        return false;
    }

    /* Zero the whole header so that the three RSV bits (fh._reserved),
     * which are never assigned below, go out as 0. No extension is
     * negotiated, and RFC 6455 section 5.2 requires the receiver to fail
     * the connection on unexpected non-zero RSV bits. */
    memset(&fh, 0, sizeof(fh));
    fh.fin = 1; /* TODO review if should fragment over some boundary */
    fh.opcode = opcode;
    fh.mask = 1; /* frames sent by this client are always masked */
    if (msglen > UINT16_MAX)
        fh.payload_len = 127;   /* 64-bit extended length follows */
    else if (msglen > 125)
        fh.payload_len = 126;   /* 16-bit extended length follows */
    else
        fh.payload_len = msglen;

    _cws_get_random(mask, sizeof(mask));

    if (!_cws_write(priv, &fh, sizeof(fh)))
        return false;

    if (fh.payload_len == 127) {
        uint64_t payload_len = msglen;
        _cws_hton(&payload_len, sizeof(payload_len));
        if (!_cws_write(priv, &payload_len, sizeof(payload_len)))
            return false;
    } else if (fh.payload_len == 126) {
        uint16_t payload_len = msglen;
        _cws_hton(&payload_len, sizeof(payload_len));
        if (!_cws_write(priv, &payload_len, sizeof(payload_len)))
            return false;
    }

    if (!_cws_write(priv, mask, sizeof(mask)))
        return false;

    return _cws_write_masked(priv, mask, msg, msglen);
}

/* Send a complete text (text == true) or binary message on the WebSocket
 * attached to `easy`. Returns false if `easy` carries no cws_data or the
 * frame could not be queued. */
bool cws_send(CURL *easy, bool text, const void *msg, size_t msglen) {
    char *p = NULL;
    enum cws_opcode opcode;

    curl_easy_getinfo(easy, CURLINFO_PRIVATE, &p); /* checks for char* */
    if (p == NULL) {
        ERR("not CWS (no CURLINFO_PRIVATE): %p", easy);
        return false;
    }

    opcode = text ? CWS_OPCODE_TEXT : CWS_OPCODE_BINARY;
    return _cws_send((cws_data *)p, opcode, msg, msglen);
}

/* Shared body of cws_ping() and cws_pong(). It looks up the cws_data attached
 * to `easy`, applies the `len == SIZE_MAX` means strlen(reason) convention
 * (a NULL reason then means an empty payload), and sends one control frame
 * of type `opcode`.
 *
 * RFC 6455 section 5.5 limits control-frame payloads to 125 bytes; longer
 * payloads are rejected here rather than sent as an invalid frame. */
static bool _cws_send_control(CURL *easy, enum cws_opcode opcode, const char *reason, size_t len) {
    cws_data *priv;
    char *p = NULL;

    curl_easy_getinfo(easy, CURLINFO_PRIVATE, &p); /* checks for char* */
    if (!p) {
        ERR("not CWS (no CURLINFO_PRIVATE): %p", easy);
        return false;
    }
    priv = (cws_data *)p;

    if (len == SIZE_MAX)
        len = reason ? strlen(reason) : 0;

    if (len > 125) {
        ERR("control frame payload too long: %zu bytes (max 125)", len);
        return false;
    }

    return _cws_send(priv, opcode, reason, len);
}

/* Send a Ping with an optional payload of at most 125 bytes. */
bool cws_ping(CURL *easy, const char *reason, size_t len) {
    return _cws_send_control(easy, CWS_OPCODE_PING, reason, len);
}

/* Send a Pong with an optional payload of at most 125 bytes. */
bool cws_pong(CURL *easy, const char *reason, size_t len) {
    return _cws_send_control(easy, CWS_OPCODE_PONG, reason, len);
}

/* Free everything owned by priv, then the curl easy handle. This happens
 * only once the state is marked deleted and no callback dispatch is in
 * progress (dispatching == 0); otherwise it is a no-op, and the callers that
 * run it again after each dispatch finish the job. */
static void _cws_cleanup(cws_data *priv) {
    CURL *easy;

    if (!priv->deleted || priv->dispatching != 0)
        return;

    /* priv is freed below; keep the handle for the final curl cleanup. */
    easy = priv->easy;

    curl_slist_free_all(priv->headers);
    free(priv->websocket_protocols.requested);
    free(priv->websocket_protocols.received);
    free(priv->send.buffer);
    free(priv->recv.current.payload);
    free(priv->recv.fragmented.payload);
    free(priv);

    curl_easy_cleanup(easy);
}

/* Send a Close frame and mark the connection closed, so later sends are
 * refused.
 *
 * reason == 0 sends an empty Close payload. Otherwise the payload is the
 * 16-bit status code in network order followed by `reason_text`
 * (reason_text_len == SIZE_MAX means strlen(reason_text); NULL means "").
 * Returns false if `easy` carries no cws_data, the payload cannot be
 * allocated, or the frame cannot be queued.
 */
bool cws_close(CURL *easy, enum cws_close_reason reason, const char *reason_text, size_t reason_text_len) {
    cws_data *priv;
    uint8_t *payload;
    size_t len;
    uint16_t r;
    bool ret;
    char *p = NULL;

    curl_easy_getinfo(easy, CURLINFO_PRIVATE, &p); /* checks for char* */
    if (!p) {
        ERR("not CWS (no CURLINFO_PRIVATE): %p", easy);
        return false;
    }
    /* Cap the transfer at 2 seconds so it does not linger after closing.
     * CURLOPT_TIMEOUT takes a long; passing a plain int through
     * curl_easy_setopt()'s varargs is undefined behavior. */
    curl_easy_setopt(easy, CURLOPT_TIMEOUT, 2L);
    priv = (cws_data *)p;

    if (reason == 0) {
        ret = _cws_send(priv, CWS_OPCODE_CLOSE, NULL, 0);
        priv->closed = true;
        return ret;
    }

    r = reason;
    if (!reason_text)
        reason_text = "";

    if (reason_text_len == SIZE_MAX)
        reason_text_len = strlen(reason_text);

    /* NOTE(review): RFC 6455 section 5.5 caps control payloads at 125 bytes
     * (i.e. 123 bytes of reason text); longer text is not truncated here. */
    len = sizeof(uint16_t) + reason_text_len;
    payload = malloc(len);
    if (!payload) {
        ERR("could not allocate memory");
        priv->closed = true;
        return false;
    }
    memcpy(payload, &r, sizeof(uint16_t));
    _cws_hton(payload, sizeof(uint16_t));
    if (reason_text_len)
        memcpy(payload + sizeof(uint16_t), reason_text, reason_text_len);

    ret = _cws_send(priv, CWS_OPCODE_CLOSE, payload, len);
    free(payload);
    priv->closed = true;
    return ret;
}

/* Compare the Sec-WebSocket-Accept response value with the key computed
 * during the handshake; sets priv->accepted only on an exact match. */
static void _cws_check_accept(cws_data *priv, const char *buffer, size_t len) {
    const size_t expected_len = sizeof(priv->accept_key) - 1;

    priv->accepted = false;

    if (len != expected_len) {
        /* %zu: both values are size_t (unsigned). */
        ERR("expected %zu bytes, got %zu '%.*s'",
            expected_len, len, (int)len, buffer);
        return;
    }

    if (memcmp(priv->accept_key, buffer, len) != 0) {
        ERR("invalid accept key '%.*s', expected '%.*s'",
            (int)len, buffer, (int)len, priv->accept_key);
        return;
    }

    priv->accepted = true;
}

/* Store a copy of the Sec-WebSocket-Protocol value chosen by the server,
 * replacing any earlier one. The stored copy is what on_connect receives. */
static void _cws_check_protocol(cws_data *priv, const char *buffer, size_t len) {
    char *previous = priv->websocket_protocols.received;

    priv->websocket_protocols.received = strndup(buffer, len);
    free(previous);
}

/* Validate the "Upgrade:" response header. RFC 6455 section 4.1 requires
 * the value to be an ASCII case-insensitive match for "websocket". Sets
 * priv->connection_websocket accordingly. */
static void _cws_check_upgrade(cws_data *priv, const char *buffer, size_t len) {
    static const char expected[] = "websocket";
    const size_t expected_len = sizeof(expected) - 1;

    priv->connection_websocket = false;

    /* Both the length and the bytes must match; comparing only values of
     * equal length would accept anything else (e.g. "h2c"). */
    if (len != expected_len ||
        strncasecmp(buffer, expected, expected_len) != 0) {
        ERR("unexpected 'Upgrade: %.*s'. Expected 'Upgrade: websocket'",
            (int)len, buffer);
        return;
    }

    priv->connection_websocket = true;
}

/* Validate the "Connection:" response header. RFC 6455 section 4.1 requires
 * it to contain the token "Upgrade" (ASCII case-insensitive). The value is a
 * comma-separated token list, e.g. "keep-alive, Upgrade". Sets
 * priv->upgraded accordingly. */
static void _cws_check_connection(cws_data *priv, const char *buffer, size_t len) {
    static const char token[] = "upgrade";
    const size_t token_len = sizeof(token) - 1;
    const char *end = buffer + len;
    const char *item = buffer;

    priv->upgraded = false;

    while (item < end) {
        const char *comma = (const char *)memchr(item, ',', (size_t)(end - item));
        const char *item_end = comma ? comma : end;
        const char *value = item;
        size_t value_len = (size_t)(item_end - item);

        _cws_trim(&value, &value_len);
        if (value_len == token_len &&
            strncasecmp(value, token, token_len) == 0) {
            priv->upgraded = true;
            return;
        }
        item = comma ? comma + 1 : end;
    }

    ERR("unexpected 'Connection: %.*s'. Expected 'Connection: upgrade'",
        (int)len, buffer);
}

/* libcurl header callback. It is called once per response header line and
 * tracks the WebSocket handshake. At the blank line that ends the headers it
 * reports the outcome through on_connect or on_close. Returning 0 there
 * (handshake not accepted) makes libcurl abort the transfer; every other line
 * is fully consumed. */
static size_t _cws_receive_header(const char *buffer, size_t count, size_t nitems, void *data) {
    /* Response headers whose values are inspected. Each handler receives the
     * value with the "Name:" prefix removed and surrounding whitespace trimmed. */
    static const struct header_checker {
        const char *prefix;
        void (*check)(cws_data *priv, const char *suffix, size_t suffixlen);
    } header_checkers[] = {
        {"Sec-WebSocket-Accept:", _cws_check_accept},
        {"Sec-WebSocket-Protocol:", _cws_check_protocol},
        {"Connection:", _cws_check_connection},
        {"Upgrade:", _cws_check_upgrade},
    };
    static const size_t n_checkers = sizeof(header_checkers) / sizeof(header_checkers[0]);
    cws_data *priv = (cws_data *)data;
    const size_t len = count * nitems;
    size_t k;

    /* A bare CRLF terminates the header block: report the handshake result. */
    if (len == 2 && memcmp(buffer, "\r\n", 2) == 0) {
        static const char refused[] = "server didn't accept the websocket upgrade";
        long status;

        curl_easy_getinfo(priv->easy, CURLINFO_HTTP_CONNECTCODE, &status);

        if (priv->accepted) {
            if (priv->cbs.on_connect) {
                priv->dispatching++;
                priv->cbs.on_connect((void *)priv->cbs.data,
                                     priv->easy,
                                     STR_OR_EMPTY(priv->websocket_protocols.received));
                priv->dispatching--;
                _cws_cleanup(priv);
            }
            return len;
        }

        if (priv->cbs.on_close) {
            priv->dispatching++;
            priv->cbs.on_close((void *)priv->cbs.data,
                               priv->easy,
                               CWS_CLOSE_REASON_SERVER_ERROR,
                               refused,
                               sizeof(refused) - 1);
            priv->dispatching--;
            _cws_cleanup(priv);
        }
        return 0;
    }

    /* Every status line starts a new response: reset per-response state. */
    if (_cws_header_has_prefix(buffer, len, "HTTP/")) {
        priv->accepted = false;
        priv->upgraded = false;
        priv->connection_websocket = false;
        free(priv->websocket_protocols.received);
        priv->websocket_protocols.received = NULL;
        return len;
    }

    for (k = 0; k < n_checkers; k++) {
        const struct header_checker *hc = &header_checkers[k];
        size_t prefixlen;
        const char *value;
        size_t valuelen;

        if (!_cws_header_has_prefix(buffer, len, hc->prefix))
            continue;

        prefixlen = strlen(hc->prefix);
        value = buffer + prefixlen;
        valuelen = len - prefixlen;
        _cws_trim(&value, &valuelen);
        hc->check(priv, value, valuelen);
    }

    return len;
}

/* Decide whether the frame in priv->recv.current may be dispatched.
 *
 * After we have closed, only the peer's Close frame is still processed. A
 * fragmented control frame, or a continuation frame with no fragmented
 * message in progress, is a protocol error: the connection is closed with
 * CWS_CLOSE_REASON_PROTOCOL_ERROR and false is returned. */
static bool _cws_dispatch_validate(cws_data *priv) {
    const enum cws_opcode opcode = priv->recv.current.opcode;
    bool fragmented_control;
    bool orphan_continuation;

    if (priv->closed && opcode != CWS_OPCODE_CLOSE)
        return false;

    fragmented_control = !priv->recv.current.fin && cws_opcode_is_control(opcode);
    orphan_continuation = opcode == CWS_OPCODE_CONTINUATION &&
                          priv->recv.fragmented.opcode == 0;

    if (!fragmented_control && !orphan_continuation)
        return true;

    if (fragmented_control)
        ERR("server sent forbidden fragmented control frame opcode=%#x.",
            opcode);
    else
        ERR("server sent continuation frame after non-fragmentable frame");

    cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, NULL, 0);
    return false;
}

/*
 * Deliver the completely received frame in priv->recv.current.
 *
 * - Text/binary messages go to on_text/on_binary once complete. Non-final
 *   fragments are parked in priv->recv.fragmented until the final
 *   continuation frame arrives.
 * - Close frames are reported via on_close and, if we have not closed yet,
 *   answered with a Close of our own.
 * - Ping frames go to on_ping or, without that callback, get an automatic
 *   Pong.
 * - Pong frames go to on_pong.
 * - Unknown opcodes close the connection with a protocol error.
 */
static void _cws_dispatch(cws_data *priv) {
    if (!_cws_dispatch_validate(priv))
        return;

    switch (priv->recv.current.opcode) {
    case CWS_OPCODE_CONTINUATION:
        if (priv->recv.current.fin) {
            /* Final fragment: current.payload holds the reassembled message. */
            if (priv->recv.fragmented.opcode == CWS_OPCODE_TEXT) {
                const char *str = (const char *)priv->recv.current.payload;
                if (priv->recv.current.used == 0)
                    str = "";
                if (priv->cbs.on_text)
                    priv->cbs.on_text((void *)priv->cbs.data, priv->easy, str, priv->recv.current.used);
            } else if (priv->recv.fragmented.opcode == CWS_OPCODE_BINARY) {
                if (priv->cbs.on_binary)
                    priv->cbs.on_binary((void *)priv->cbs.data, priv->easy, priv->recv.current.payload, priv->recv.current.used);
            }
            memset(&priv->recv.fragmented, 0, sizeof(priv->recv.fragmented));
        } else {
            /* More fragments follow: park the data (fragmented.opcode is kept). */
            priv->recv.fragmented.payload = priv->recv.current.payload;
            priv->recv.fragmented.used = priv->recv.current.used;
            priv->recv.fragmented.total = priv->recv.current.total;
            priv->recv.current.payload = NULL;
            priv->recv.current.used = 0;
            priv->recv.current.total = 0;
        }
        break;

    case CWS_OPCODE_TEXT:
        if (priv->recv.current.fin) {
            const char *str = (const char *)priv->recv.current.payload;
            if (priv->recv.current.used == 0)
                str = "";
            if (priv->cbs.on_text)
                priv->cbs.on_text((void *)priv->cbs.data, priv->easy, str, priv->recv.current.used);
        } else {
            /* First fragment of a text message: park it and remember its type. */
            priv->recv.fragmented.payload = priv->recv.current.payload;
            priv->recv.fragmented.used = priv->recv.current.used;
            priv->recv.fragmented.total = priv->recv.current.total;
            priv->recv.fragmented.opcode = priv->recv.current.opcode;

            priv->recv.current.payload = NULL;
            priv->recv.current.used = 0;
            priv->recv.current.total = 0;
            priv->recv.current.opcode = CWS_OPCODE_CONTINUATION;
            priv->recv.current.fin = 0;
        }
        break;

    case CWS_OPCODE_BINARY:
        if (priv->recv.current.fin) {
            if (priv->cbs.on_binary)
                priv->cbs.on_binary((void *)priv->cbs.data, priv->easy, priv->recv.current.payload, priv->recv.current.used);
        } else {
            /* First fragment of a binary message: park it and remember its type. */
            priv->recv.fragmented.payload = priv->recv.current.payload;
            priv->recv.fragmented.used = priv->recv.current.used;
            priv->recv.fragmented.total = priv->recv.current.total;
            priv->recv.fragmented.opcode = priv->recv.current.opcode;

            priv->recv.current.payload = NULL;
            priv->recv.current.used = 0;
            priv->recv.current.total = 0;
            priv->recv.current.opcode = CWS_OPCODE_CONTINUATION;
            priv->recv.current.fin = 0;
        }
        break;

    case CWS_OPCODE_CLOSE: {
        enum cws_close_reason reason = CWS_CLOSE_REASON_NO_REASON;
        const char *str = "";
        /* Length of the reason text after the 2-byte status code. It stays 0
         * unless a status code is present, so `str` and `len` always agree. */
        size_t len = 0;

        if (priv->recv.current.used >= sizeof(uint16_t)) {
            uint16_t r;
            memcpy(&r, priv->recv.current.payload, sizeof(uint16_t));
            _cws_ntoh(&r, sizeof(r));
            if (!cws_close_reason_is_valid((enum cws_close_reason)r)) {
                cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, "invalid close reason", SIZE_MAX);
                r = CWS_CLOSE_REASON_PROTOCOL_ERROR;
            }
            reason = (enum cws_close_reason)r;
            str = (const char *)priv->recv.current.payload + sizeof(uint16_t);
            len = priv->recv.current.used - sizeof(uint16_t);
        } else if (priv->recv.current.used > 0) {
            /* A 1-byte payload cannot hold a status code: protocol error. */
            cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, "invalid close payload length", SIZE_MAX);
            reason = CWS_CLOSE_REASON_PROTOCOL_ERROR;
        }

        if (priv->cbs.on_close)
            priv->cbs.on_close((void *)priv->cbs.data, priv->easy, reason, str, len);

        /* Answer the peer's Close if we have not sent one ourselves. */
        if (!priv->closed) {
            if (reason == CWS_CLOSE_REASON_NO_REASON)
                reason = CWS_CLOSE_REASON_DEFAULT;
            cws_close(priv->easy, reason, str, len);
        }
        break;
    }

    case CWS_OPCODE_PING: {
        const char *str = (const char *)priv->recv.current.payload;
        if (priv->recv.current.used == 0)
            str = "";
        if (priv->cbs.on_ping)
            priv->cbs.on_ping((void *)priv->cbs.data, priv->easy, str, priv->recv.current.used);
        else
            cws_pong(priv->easy, str, priv->recv.current.used);
        break;
    }

    case CWS_OPCODE_PONG: {
        const char *str = (const char *)priv->recv.current.payload;
        if (priv->recv.current.used == 0)
            str = "";
        if (priv->cbs.on_pong)
            priv->cbs.on_pong((void *)priv->cbs.data, priv->easy, str, priv->recv.current.used);
        break;
    }

    default:
        ERR("unexpected WebSocket opcode: %#x.", priv->recv.current.opcode);
        cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, "unexpected opcode", SIZE_MAX);
    }
}

static size_t _cws_process_frame(cws_data *priv, const char *buffer, size_t len) {
    size_t used = 0;

    while (len > 0 && priv->recv.done < priv->recv.needed) {
        uint64_t frame_len;

        if (priv->recv.done < priv->recv.needed) {
            size_t todo = priv->recv.needed - priv->recv.done;
            if (todo > len)
                todo = len;
            memcpy(priv->recv.tmpbuf + priv->recv.done, buffer, todo);
            priv->recv.done += todo;
            used += todo;
            buffer += todo;
            len -= todo;
        }

        if (priv->recv.needed != priv->recv.done)
            continue;

        if (priv->recv.needed == sizeof(cws_frame_header)) {
            cws_frame_header fh;

            memcpy(&fh, priv->recv.tmpbuf, sizeof(cws_frame_header));
            priv->recv.current.opcode = (cws_opcode)fh.opcode;
            priv->recv.current.fin = fh.fin;

            if (fh._reserved || fh.mask)
                cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, NULL, 0);

            if (fh.payload_len == 126) {
                if (cws_opcode_is_control((cws_opcode)fh.opcode))
                    cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, NULL, 0);
                priv->recv.needed += sizeof(uint16_t);
                continue;
            } else if (fh.payload_len == 127) {
                if (cws_opcode_is_control((cws_opcode)fh.opcode))
                    cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, NULL, 0);
                priv->recv.needed += sizeof(uint64_t);
                continue;
            } else
                frame_len = fh.payload_len;
        } else if (priv->recv.needed == sizeof(cws_frame_header) + sizeof(uint16_t)) {
            uint16_t plen;

            memcpy(&plen,
                   priv->recv.tmpbuf + sizeof(cws_frame_header),
                   sizeof(plen));
            _cws_ntoh(&plen, sizeof(plen));
            frame_len = plen;
        } else if (priv->recv.needed == sizeof(cws_frame_header) + sizeof(uint64_t)) {
            uint64_t plen;

            memcpy(&plen, priv->recv.tmpbuf + sizeof(cws_frame_header),
                   sizeof(plen));
            _cws_ntoh(&plen, sizeof(plen));
            frame_len = plen;
        } else {
            ERR("needed=%u, done=%u", priv->recv.needed, priv->recv.done);
            abort();
        }

        if (priv->recv.current.opcode == CWS_OPCODE_CONTINUATION) {
            if (priv->recv.fragmented.opcode == 0)
                cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, "nothing to continue", SIZE_MAX);
            if (priv->recv.current.payload)
                free(priv->recv.current.payload);

            priv->recv.current.payload = priv->recv.fragmented.payload;
            priv->recv.current.used = priv->recv.fragmented.used;
            priv->recv.current.total = priv->recv.fragmented.total;
            priv->recv.fragmented.payload = NULL;
            priv->recv.fragmented.used = 0;
            priv->recv.fragmented.total = 0;
        } else if (!cws_opcode_is_control(priv->recv.current.opcode) && priv->recv.fragmented.opcode != 0) {
            cws_close(priv->easy, CWS_CLOSE_REASON_PROTOCOL_ERROR, "expected continuation or control frames", SIZE_MAX);
        }

        if (frame_len > 0) {
            void *tmp;

            tmp = realloc(priv->recv.current.payload,
                          priv->recv.current.total + frame_len + 1);
            if (!tmp) {
                cws_close(priv->easy, CWS_CLOSE_REASON_TOO_BIG, NULL, 0);
                ERR("could not allocate memory");
                return CURL_READFUNC_ABORT;
            }
            priv->recv.current.payload = (uint8_t*)tmp;
            priv->recv.current.total += frame_len;
        }
    }

    if (len == 0 && priv->recv.done < priv->recv.needed)
        return used;

    /* fill payload */
    while (len > 0 && priv->recv.current.used < priv->recv.current.total) {
        size_t todo = priv->recv.current.total - priv->recv.current.used;
        if (todo > len)
            todo = len;
        memcpy(priv->recv.current.payload + priv->recv.current.used, buffer, todo);
        priv->recv.current.used += todo;
        used += todo;
        buffer += todo;
        len -= todo;
    }

    if (priv->recv.current.payload)
        priv->recv.current.payload[priv->recv.current.used] = '\0';

    if (len == 0 && priv->recv.current.used < priv->recv.current.total)
        return used;

    priv->dispatching++;

    _cws_dispatch(priv);

    priv->recv.done = 0;
    priv->recv.needed = sizeof(cws_frame_header);
    priv->recv.current.used = 0;
    priv->recv.current.total = 0;

    priv->dispatching--;
    _cws_cleanup(priv);

    return used;
}

static size_t _cws_receive_data(const char *buffer, size_t count, size_t nitems, void *data) {
    cws_data *priv = (cws_data *)data;
    size_t len = count * nitems;
    while (len > 0) {
        size_t used = _cws_process_frame(priv, buffer, len);
        len -= used;
        buffer += used;
    }

    return count * nitems;
}

/* libcurl read callback: hand over as much of the queued outgoing data as
 * fits in curl's buffer. With nothing queued it pauses sending
 * (CURL_READFUNC_PAUSE); _cws_write() lifts the pause when data arrives. */
static size_t _cws_send_data(char *buffer, size_t count, size_t nitems, void *data) {
    cws_data *priv = (cws_data *)data;
    const size_t room = count * nitems;
    size_t n;

    if (priv->send.len == 0) {
        priv->pause_flags |= CURLPAUSE_SEND;
        return CURL_READFUNC_PAUSE;
    }

    n = (priv->send.len < room) ? priv->send.len : room;
    memcpy(buffer, priv->send.buffer, n);
    priv->send.len -= n;

    if (priv->send.len > 0) {
        /* optimization note: keeping a read offset (priv->send.position)
         * would avoid this memmove(); _cws_write() could then compact the
         * buffer before growing it.
         */
        memmove(priv->send.buffer, priv->send.buffer + n, priv->send.len);
    } else {
        free(priv->send.buffer);
        priv->send.buffer = NULL;
    }

    return n;
}

/*
 * Generate a random Sec-WebSocket-Key and compute the matching expected
 * Sec-WebSocket-Accept value.
 *
 * key_header: "Sec-WebSocket-Key: " followed by 24 placeholder characters;
 *             the placeholders are overwritten with the base64 key.
 * priv:       receives the expected accept value in priv->accept_key.
 *
 * Per RFC 6455 the accept value is
 *   base64(SHA1(key_base64 + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")),
 * i.e. the SHA-1 of exactly 24 + 36 = 60 bytes.
 *
 * Returns key_header.
 */
static const char *_cws_fill_websocket_key(cws_data *priv, char key_header[44]) {
    uint8_t key[16];
    /* 24 placeholder bytes for the base64 key followed by the GUID. The
     * array's 61st byte is only the string literal's NUL terminator and is
     * NOT part of the hashed input. */
    char buf[61] = "01234567890123456789....258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    /* A SHA-1 digest is 20 bytes. _cws_SHA1() may additionally store a NUL at
     * hash_out[20], so keep one spare byte to avoid overrunning this array. */
    uint8_t sha1hash[20 + 1];

    _cws_get_random(key, sizeof(key));

    /* 16 random bytes -> 24 base64 characters, replacing the placeholder. */
    _cws_encode_base64(key, sizeof(key), buf);
    memcpy(key_header + strlen("Sec-WebSocket-Key: "), buf, 24);

    /* Hash key + GUID only (60 bytes). Including the trailing NUL would make
     * the expected accept value differ from what an RFC 6455 server sends. */
    _cws_SHA1(sha1hash, buf, (int)(sizeof(buf) - 1));
    _cws_encode_base64(sha1hash, 20, priv->accept_key);
    priv->accept_key[sizeof(priv->accept_key) - 1] = '\0';

    return key_header;
}

CURL *cws_new(const char *url, const char *websocket_protocols, const cws_callbacks *callbacks) {
    CURL *easy;
    cws_data *priv;
    char key_header[] = "Sec-WebSocket-Key: 01234567890123456789....";
    char *tmp = NULL;
    const curl_version_info_data *cver = curl_version_info(CURLVERSION_NOW);

    if (cver->version_num < 0x073202)
        ERR("CURL version '%s'. At least '7.50.2' is required for WebSocket to work reliably", cver->version);

    if (!url)
        return NULL;

    easy = curl_easy_init();
    if (!easy)
        return NULL;

    priv = (cws_data *)calloc(1, sizeof(cws_data));
    priv->easy = easy;
    curl_easy_setopt(easy, CURLOPT_PRIVATE, priv);
    curl_easy_setopt(easy, CURLOPT_HEADERFUNCTION, _cws_receive_header);
    curl_easy_setopt(easy, CURLOPT_HEADERDATA, priv);
    curl_easy_setopt(easy, CURLOPT_WRITEFUNCTION, _cws_receive_data);
    curl_easy_setopt(easy, CURLOPT_WRITEDATA, priv);
    curl_easy_setopt(easy, CURLOPT_READFUNCTION, _cws_send_data);
    curl_easy_setopt(easy, CURLOPT_READDATA, priv);

    if (callbacks)
        priv->cbs = *callbacks;

    priv->recv.needed = sizeof(cws_frame_header);
    priv->recv.done = 0;

    /* curl doesn't support ws:// or wss:// scheme, rewrite to http/https */
    if (strncmp(url, "ws://", strlen("ws://")) == 0) {
        tmp = (char*)malloc(strlen(url) - strlen("ws://") + strlen("http://") + 1);
        memcpy(tmp, "http://", strlen("http://"));
        memcpy(tmp + strlen("http://"),
               url + strlen("ws://"),
               strlen(url) - strlen("ws://") + 1);
        url = tmp;
    } else if (strncmp(url, "wss://", strlen("wss://")) == 0) {
        tmp =(char*) malloc(strlen(url) - strlen("wss://") + strlen("https://") + 1);
        memcpy(tmp, "https://", strlen("https://"));
        memcpy(tmp + strlen("https://"),
               url + strlen("wss://"),
               strlen(url) - strlen("wss://") + 1);
        url = tmp;
    }
    curl_easy_setopt(easy, CURLOPT_URL, url);
    free(tmp);

    /*
     * BEGIN: work around CURL to get WebSocket:
     *
     * WebSocket must be HTTP/1.1 GET request where we must keep the
     * "send" part alive without any content-length and no chunked
     * encoding and the server answer is 101-upgrade.
     */
    curl_easy_setopt(easy, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1);
    /* Use CURLOPT_UPLOAD=1 to force "send" even with a GET request,
     * however it will set HTTP request to PUT
     */
    curl_easy_setopt(easy, CURLOPT_UPLOAD, 1L);
    /*
     * Then we manually override the string sent to be "GET".
     */
    curl_easy_setopt(easy, CURLOPT_CUSTOMREQUEST, "GET");
    /*
     * CURLOPT_UPLOAD=1 with HTTP/1.1 implies:
     *     Expect: 100-continue
     * but we don't want that, rather 101. Then force: 101.
     */
    priv->headers = curl_slist_append(priv->headers, "Expect: 101");
    /*
     * CURLOPT_UPLOAD=1 without a size implies in:
     *     Transfer-Encoding: chunked
     * but we don't want that, rather unmodified (raw) bites as we're
     * doing the websockets framing ourselves. Force nothing.
     */
    priv->headers = curl_slist_append(priv->headers, "Transfer-Encoding:");
    /* END: work around CURL to get WebSocket. */

    /* regular mandatory WebSockets headers */
    priv->headers = curl_slist_append(priv->headers, "Connection: Upgrade");
    priv->headers = curl_slist_append(priv->headers, "Upgrade: websocket");
    priv->headers = curl_slist_append(priv->headers, "Sec-WebSocket-Version: 13");
    /* Sec-WebSocket-Key: <24-bytes-base64-of-random-key> */
    priv->headers = curl_slist_append(priv->headers, _cws_fill_websocket_key(priv, key_header));

    if (websocket_protocols) {
        char *tmp = (char *)malloc(strlen("Sec-WebSocket-Protocol: ") +
                           strlen(websocket_protocols) + 1);
        memcpy(tmp,
               "Sec-WebSocket-Protocol: ",
               strlen("Sec-WebSocket-Protocol: "));
        memcpy(tmp + strlen("Sec-WebSocket-Protocol: "),
               websocket_protocols,
               strlen(websocket_protocols) + 1);

        priv->headers = curl_slist_append(priv->headers, tmp);
        free(tmp);
        priv->websocket_protocols.requested = strdup(websocket_protocols);
    }

    curl_easy_setopt(easy, CURLOPT_HTTPHEADER, priv->headers);

    return easy;
}

/*
 * Release a handle created by cws_new().
 *
 * Looks up the cws_data attached through CURLOPT_PRIVATE; if there is none the
 * call does nothing. Otherwise the data is flagged as deleted and handed to
 * _cws_cleanup().
 * NOTE(review): presumably _cws_cleanup() defers the real teardown while a
 * frame is being dispatched (priv->dispatching) — confirm.
 */
void cws_free(CURL *easy) {
    char *private_ptr = NULL; /* CURLINFO_PRIVATE is retrieved as a char * */
    cws_data *priv;

    curl_easy_getinfo(easy, CURLINFO_PRIVATE, &private_ptr);
    if (private_ptr == NULL)
        return;

    priv = (cws_data *)private_ptr;
    priv->deleted = true;
    _cws_cleanup(priv);
}


///////////////////////////////////////////////////////////////////////////////////////////
//sha1
/*
 * Running SHA-1 computation state.
 *   state  - the five 32-bit chaining variables.
 *   count  - total message length in bits as a 64-bit value split in two
 *            words: count[0] holds the low 32 bits, count[1] the high 32 bits.
 *   buffer - partially filled 64-byte input block awaiting compression.
 */
typedef struct _cws_SHA1_ctx_s {
	uint32_t state[5];
	uint32_t count[2];
	unsigned char buffer[64];
} _cws_SHA1_CTX;

/* Rotate a 32-bit word left by `bits` (0 < bits < 32). */
static inline uint32_t _cws_SHA1Rol(uint32_t value, unsigned bits)
{
	return (value << bits) | (value >> (32u - bits));
}

/*
 * Hash a single 512-bit block into `state`; the core SHA-1 compression
 * function (FIPS 180-4, section 6.1.2).
 *
 * The 16 message words are assembled big-endian byte by byte. The result
 * therefore does not depend on host byte order (no BYTE_ORDER detection that
 * could silently guess wrong), and `buffer` is only read: it may be unaligned
 * and is never modified, honouring its `const` qualifier.
 *
 * state:  the five 32-bit chaining variables, updated in place.
 * buffer: 64 bytes of message data.
 */
void _cws_SHA1Transform(
	uint32_t state[5],
	const unsigned char buffer[64]
)
{
	uint32_t w[80];   /* message schedule */
	uint32_t a, b, c, d, e;
	unsigned t;

	for (t = 0; t < 16; t++)
	{
		const unsigned char *p = buffer + 4 * t;

		w[t] = ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
		       ((uint32_t)p[2] << 8) | (uint32_t)p[3];
	}
	for (t = 16; t < 80; t++)
		w[t] = _cws_SHA1Rol(w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16], 1);

	a = state[0];
	b = state[1];
	c = state[2];
	d = state[3];
	e = state[4];

	/* 4 rounds of 20 operations each. */
	for (t = 0; t < 80; t++)
	{
		uint32_t f, k, temp;

		if (t < 20)
		{
			f = (b & (c ^ d)) ^ d;          /* Ch(b, c, d) */
			k = 0x5A827999;
		}
		else if (t < 40)
		{
			f = b ^ c ^ d;                  /* Parity */
			k = 0x6ED9EBA1;
		}
		else if (t < 60)
		{
			f = ((b | c) & d) | (b & c);    /* Maj(b, c, d) */
			k = 0x8F1BBCDC;
		}
		else
		{
			f = b ^ c ^ d;                  /* Parity */
			k = 0xCA62C1D6;
		}

		temp = _cws_SHA1Rol(a, 5) + f + e + k + w[t];
		e = d;
		d = c;
		c = _cws_SHA1Rol(b, 30);
		b = a;
		a = temp;
	}

	/* Add the working vars back into the chaining state. */
	state[0] += a;
	state[1] += b;
	state[2] += c;
	state[3] += d;
	state[4] += e;
}


/* SHA1Init - Initialize new context */

/*
 * SHA1Init - reset `context` to the SHA-1 initial hash value and an empty
 * (zero-bit) message length. The partial-block buffer is left as is; it is
 * only read up to the length recorded in context->count.
 */
void _cws_SHA1Init(
	_cws_SHA1_CTX * context
)
{
	static const uint32_t initial_state[5] = {
		0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
	};

	memcpy(context->state, initial_state, sizeof(initial_state));
	context->count[1] = 0;
	context->count[0] = 0;
}


/* Run your data through this. */

/*
 * Absorb `len` bytes of `data` into the running hash.
 *
 * Input is always staged through context->buffer, and the block transform
 * only ever runs on that writable, context-owned buffer. The caller's `const`
 * bytes are never handed to _cws_SHA1Transform() directly, so the transform
 * can neither modify them nor read them through a misaligned word pointer.
 *
 * context: state previously set up by _cws_SHA1Init().
 * data:    input bytes (may be NULL when len == 0).
 * len:     number of bytes to absorb.
 */
void _cws_SHA1Update(
	_cws_SHA1_CTX * context,
	const unsigned char *data,
	uint32_t len
)
{
	/* Bytes already waiting in context->buffer: (bit count / 8) mod 64. */
	uint32_t buffered = (context->count[0] >> 3) & 63;
	uint32_t old_low = context->count[0];

	/* 64-bit bit counter: count[1] is the high word, count[0] the low word. */
	context->count[0] += len << 3;
	if (context->count[0] < old_low)
		context->count[1]++;            /* carry out of the low word */
	context->count[1] += len >> 29;    /* bits of (len << 3) above 32 */

	while (len > 0)
	{
		uint32_t take = 64 - buffered;

		if (take > len)
			take = len;
		memcpy(&context->buffer[buffered], data, take);
		buffered += take;
		data += take;
		len -= take;

		if (buffered == 64)
		{
			_cws_SHA1Transform(context->state, context->buffer);
			buffered = 0;
		}
	}
}


/* Add padding and return the message digest. */

/*
 * Finish the hash and write the 20-byte digest.
 *
 * Appends the 0x80 marker byte, zero-pads until the bit count is 448 mod 512,
 * then appends the original 64-bit big-endian bit length, which completes the
 * last block. The digest is the five chaining words, most significant byte
 * first. Afterwards `context` is zeroed, so it must be re-initialised with
 * _cws_SHA1Init() before reuse.
 */
void _cws_SHA1Final(
	unsigned char digest[20],
	_cws_SHA1_CTX * context
)
{
	static const unsigned char pad_first = 0x80;
	static const unsigned char pad_zero = 0x00;
	unsigned char bit_length[8];
	unsigned n;

	/* Snapshot the message length before padding changes context->count:
	 * high word (count[1]) first, each word most-significant byte first. */
	for (n = 0; n < 4; n++)
	{
		unsigned shift = 24 - 8 * n;

		bit_length[n] = (unsigned char)(context->count[1] >> shift);
		bit_length[n + 4] = (unsigned char)(context->count[0] >> shift);
	}

	_cws_SHA1Update(context, &pad_first, 1);
	/* count[0] & 504 is the bit count mod 512 (its low 3 bits are always 0). */
	while ((context->count[0] & 504) != 448)
		_cws_SHA1Update(context, &pad_zero, 1);
	_cws_SHA1Update(context, bit_length, 8);

	for (n = 0; n < 20; n++)
		digest[n] = (unsigned char)(context->state[n / 4] >> (24 - 8 * (n % 4)));

	/* Wipe variables */
	memset(context, '\0', sizeof(*context));
	memset(bit_length, '\0', sizeof(bit_length));
}

/*
 * Compute the SHA-1 digest of the first `len` bytes of `str`.
 *
 * hash_out: receives exactly 20 digest bytes. The digest is binary (it may
 *           contain zero bytes) and is NOT NUL-terminated, so a 20-byte
 *           buffer is sufficient.
 * str:      input bytes; may contain NULs since the length is explicit.
 * len:      number of bytes to hash; a value <= 0 hashes the empty message.
 */
void _cws_SHA1(
	uint8_t *hash_out,
	const char *str,
	int len)
{
	_cws_SHA1_CTX ctx;
	int pos;

	_cws_SHA1Init(&ctx);
	/* Feeding one byte per call keeps every block compression on ctx.buffer
	 * rather than on the caller's const memory. A signed index also keeps a
	 * negative `len` from being converted to a huge unsigned bound. */
	for (pos = 0; pos < len; pos++)
		_cws_SHA1Update(&ctx, (const unsigned char *)str + pos, 1);
	_cws_SHA1Final((unsigned char *)hash_out, &ctx);
}
